import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns; sns.set_style('darkgrid')
color_palette = sns.color_palette()

# Key of each loaded results dict that is summarised below. Swap the two
# lines to report training accuracy instead of test accuracy.
# mode = 'train_accs'
mode = 'test_accs'

# Baseline
# One results file per seed (1..10). The last entry stored under `mode` is
# taken as that run's final value; a final value below 12.0 is counted as a
# complete failure.
data = np.zeros([10])
for idx in range(data.size):
  seed = idx + 1
  results = torch.load('cloud_logs/seed%d-nonTrue-dropoutFalse-actReLU-outsoftmax-optAdam-slr0.002000-blr0.002000-Default.dict' % seed)
  data[idx] = results[mode][-1]
complete_fails = int(np.count_nonzero(data < 12.0))
model_mean = data.mean()
# Standard error of the mean: sample std (ddof=1) over sqrt(number of seeds).
model_std = data.std(ddof=1) / np.sqrt(data.size)
print("Baseline")
for value in (model_mean, model_std, complete_fails):
  print(value)

# Experiment 1

# Effect of m
# Rows of m_mean / m_std follow the aux settings (50, 100, 200); columns
# follow m_list. Each cell summarises the final `mode` value over seeds 1..10.
m_list = [20,40,60,80]
aux_settings = [50,100,200]
m_path = 'cloud_logs/seed%d-m%d-aux%d-dropoutFalse-actReLU-outsoftmax-optAdam-slr0.002000-blr0.000010-Default.dict'
m_mean = np.zeros([len(aux_settings), len(m_list)])
m_std = np.zeros([len(aux_settings), len(m_list)])
num_complete_fails = []
for row, aux in enumerate(aux_settings):
  for col, m in enumerate(m_list):
    data = np.zeros([10])
    for idx in range(data.size):
      seed = idx + 1
      results = torch.load(m_path % (seed,m,aux))
      data[idx] = results[mode][-1]
    # Runs whose final value is below 12.0 count as complete failures.
    complete_fails = int(np.count_nonzero(data < 12.0))
    num_complete_fails.append(complete_fails)
    m_mean[row,col] = data.mean()
    # Standard error of the mean across the seeds.
    m_std[row,col] = data.std(ddof=1) / np.sqrt(data.size)

print("Effect of m")
for item in (m_list, m_mean, m_std, num_complete_fails):
  print(item)

# Effect of aux
# Rows of aux_mean / aux_std follow the m settings (100, 200, 400); columns
# follow aux_list. Each cell summarises the final `mode` value over seeds 1..10.
aux_list = [200, 400, 600, 800]
m_settings = [100,200,400]
aux_path = 'cloud_logs/seed%d-m%d-aux%d-dropoutFalse-actReLU-outsoftmax-optAdam-slr0.002000-blr0.000010-Default.dict'
aux_mean = np.zeros([len(m_settings), len(aux_list)])
aux_std = np.zeros([len(m_settings), len(aux_list)])
num_complete_fails = []
for row, m in enumerate(m_settings):
  for col, aux in enumerate(aux_list):
    data = np.zeros([10])
    for idx in range(data.size):
      seed = idx + 1
      results = torch.load(aux_path % (seed,m,aux))
      data[idx] = results[mode][-1]
    # Runs whose final value is below 12.0 count as complete failures.
    complete_fails = int(np.count_nonzero(data < 12.0))
    num_complete_fails.append(complete_fails)
    aux_mean[row,col] = data.mean()
    # Standard error of the mean across the seeds.
    aux_std[row,col] = data.std(ddof=1) / np.sqrt(data.size)
print("Effect of aux")
for item in (aux_list, aux_mean, aux_std, num_complete_fails):
  print(item)

# Two-panel figure for Experiment 1: mean curves with a +/- one standard
# error band, saved to 'm_aux.pdf'.
plt.figure(figsize=(12,5))

# Label the y-axis after the metric actually selected by `mode`, so the plot
# is not mislabelled when training accuracy is being summarised.
accuracy_label = {'train_accs': 'Train Accuracy', 'test_accs': 'Test Accuracy'}.get(mode, mode)

# Left panel: one curve per aux setting, plotted against m.
plt.subplot(121)
ax = plt.gca()
for row, (aux_value, palette_idx) in enumerate(zip([50,100,200], [0,2,4])):
  colour = color_palette[palette_idx]
  ax.fill_between(m_list, m_mean[row]-m_std[row], m_mean[row]+m_std[row], alpha=0.4, color=colour)
  # x/y are passed by keyword: recent seaborn releases reject them positionally.
  sns.lineplot(x=m_list, y=m_mean[row], color=colour, label='aux=%d' % aux_value, ax=ax)
plt.xticks(m_list)
plt.ylabel(accuracy_label)
plt.xlabel('Number of Clauses (m)')
plt.title('Effect of m')

# Right panel: one curve per m setting, plotted against aux.
plt.subplot(122)
ax = plt.gca()
for row, (m_value, palette_idx) in enumerate(zip([100,200,400], [1,3,5])):
  colour = color_palette[palette_idx]
  ax.fill_between(aux_list, aux_mean[row]-aux_std[row], aux_mean[row]+aux_std[row], alpha=0.4, color=colour)
  sns.lineplot(x=aux_list, y=aux_mean[row], color=colour, label='m=%d' % m_value, ax=ax)
plt.xticks(aux_list)
plt.ylabel(accuracy_label)
plt.xlabel('Number of Auxiliary Variables (aux)')
plt.title('Effect of aux')

plt.tight_layout()
# plt.show()
plt.savefig('m_aux.pdf')
# Clear the figure so later plotting does not draw on top of this one.
plt.clf()

# Experiment 2

# Effect of Learning Rates
# Sweep every (SATNet lr, backbone lr) pair and print one summary line per
# pair: mean, standard error and number of complete failures over seeds 1..10.
learning_rates = [1e-3,1e-4,1e-5]
lr_path = 'cloud_logs/seed%d-m200-aux100-dropoutFalse-actReLU-outsoftmax-optAdam-slr%f-blr%f-Default.dict'
for slr in learning_rates:
  for blr in learning_rates:
    data = np.zeros([10])
    for idx in range(data.size):
      seed = idx + 1
      results = torch.load(lr_path % (seed,slr,blr))
      data[idx] = results[mode][-1]
    # Runs whose final value is below 12.0 count as complete failures.
    complete_fails = int(np.count_nonzero(data < 12.0))
    mean = data.mean()
    std = data.std(ddof=1) / np.sqrt(data.size)
    print("SATNet lr:%f Backbone lr:%f Mean:%.1f Std:%.1f Complete Fails:%d" % (slr,blr,mean,std,complete_fails))

# Experiment 3

# SGD vs Adam
# Summarise the final `mode` value over seeds 1..10 for the Adam run and the
# SGD run. Each optimiser keeps its own failure count; reusing one counter
# for both would make the Adam line report the SGD count.
optimizer_paths = (
  ('Adam', 'cloud_logs/seed%d-m200-aux100-dropoutFalse-actReLU-outsoftmax-optAdam-slr0.001000-blr0.000010-Default.dict'),
  ('SGD', 'cloud_logs/seed%d-m200-aux100-dropoutFalse-actReLU-outsoftmax-optSGD-slr0.001000-blr0.100000-Default.dict'),
)
optimizer_summary = {}
for optimizer_name, path_template in optimizer_paths:
  data = np.zeros([10])
  for seed in range(1,11):
    results = torch.load(path_template % seed)
    data[seed-1] = results[mode][-1]
  # Runs whose final value is below 12.0 count as complete failures.
  complete_fails = int(np.count_nonzero(data < 12.0))
  # Mean, standard error of the mean, and failure count for this optimiser.
  optimizer_summary[optimizer_name] = (data.mean(), data.std(ddof=1)/np.sqrt(10), complete_fails)
adam_mean, adam_std, adam_fails = optimizer_summary['Adam']
sgd_mean, sgd_std, sgd_fails = optimizer_summary['SGD']
print("Adam Mean:%.1f Std:%.1f Complete Fails:%d" % (adam_mean,adam_std,adam_fails))
print("SGD Mean:%.1f Std:%.1f Complete Fails:%d" % (sgd_mean,sgd_std,sgd_fails))

# Experiment 4

# Effect of Architecture Size and Softmax vs Sigmoid
# One entry per (output layer, backbone) pair, in loop order: mean, standard
# error, number of complete failures, and the 'num_params' value read from
# the seed-1 results file.
arch_path = 'cloud_logs/seed%d-m200-aux100-dropoutFalse-actReLU-out%s-optSGD-slr0.001000-blr0.100000-%s.dict'
model_mean, model_std = [], []
num_complete_fails, params = [], []
for out in ['softmax','sigmoid']:
  for backbone in ['LeNet', 'Default', 'ResNet18']:
    data = np.zeros([10])
    for idx in range(data.size):
      seed = idx + 1
      results = torch.load(arch_path % (seed,out,backbone))
      data[idx] = results[mode][-1]
      if idx == 0:
        # Only the first seed's file is consulted for the parameter count.
        params.append(results['num_params'])
    # Runs whose final value is below 12.0 count as complete failures.
    num_complete_fails.append(int(np.count_nonzero(data < 12.0)))
    model_mean.append(data.mean())
    model_std.append(data.std(ddof=1) / np.sqrt(data.size))
print("Effect of Architecture Size and Softmax vs Sigmoid")
for item in (model_mean, model_std, num_complete_fails, params):
  print(item)